{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "63ff6846",
   "metadata": {},
   "source": [
    "## Boundless DAS"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3ae11b28",
   "metadata": {},
   "outputs": [],
   "source": [
    "__author__ = \"Zhengxuan Wu\"\n",
    "__version__ = \"10/05/2023\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7d898fce",
   "metadata": {},
   "source": [
    "### Overview\n",
    "\n",
    "This tutorial aims to reproduce one key result of [the Boundless DAS paper](https://arxiv.org/pdf/2305.08809). It uses the same price tagging dataset as the paper and focuses on finding an alignment for the left boundary check only."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8af5dff0",
   "metadata": {},
   "source": [
    "### Set-up"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "3e3c09e2",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2024-01-11 01:35:34,365] [INFO] [real_accelerator.py:158:get_accelerator] Setting ds_accelerator to cuda (auto detect)\n"
     ]
    }
   ],
   "source": [
    "try:\n",
    "    # pyvene doubles as the indicator for whether the tutorial's\n",
    "    # dependencies are already installed in this kernel.\n",
    "    import pyvene\n",
    "\n",
    "except ModuleNotFoundError:\n",
    "    # %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "    %pip install git+https://github.com/stanfordnlp/pyvene.git"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "9a39c2b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from tqdm import tqdm, trange\n",
    "from datasets import Dataset\n",
    "from torch.utils.data import DataLoader\n",
    "from transformers import get_linear_schedule_with_warmup\n",
    "from torch.nn import CrossEntropyLoss\n",
    "from tutorial_price_tagging_utils import (\n",
    "    factual_sampler,\n",
    "    bound_alignment_sampler,\n",
    "    lower_bound_alignment_example_sampler,\n",
    ")\n",
    "\n",
    "from pyvene import (\n",
    "    IntervenableModel,\n",
    "    BoundlessRotatedSpaceIntervention,\n",
    "    RepresentationConfig,\n",
    "    IntervenableConfig,\n",
    ")\n",
    "from pyvene import create_llama\n",
    "from pyvene import set_seed, count_parameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "970a8f7d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7ae4a40cfb8d44a38578fd45815ab319",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading config.json:   0%|          | 0.00/550 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1220a638dba243d49cc5f951cd4d9364",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer_config.json:   0%|          | 0.00/726 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d9fd255929ed487b8440ed3dae85df58",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer.model:   0%|          | 0.00/500k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "37973ed21b144db99a242a249ffba7ea",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading added_tokens.json:   0%|          | 0.00/21.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "307b905f058e48f19eab80f03dcc2496",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading (…)cial_tokens_map.json:   0%|          | 0.00/435 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "53793dbe43e54aa694d2cc0aff7b9a1d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Downloading tokenizer.json:   0%|          | 0.00/1.84M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama.LlamaTokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565\n",
      "normalizer.cc(51) LOG(INFO) precompiled_charsmap is empty. use identity normalization.\n",
      "Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cbf4bf617f9a4099bf3525e0c5334ed6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Loading checkpoint shards:   0%|          | 0/34 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "loaded model\n"
     ]
    }
   ],
   "source": [
    "# pyvene helper returning the model config, tokenizer and model.\n",
    "config, tokenizer, llama = create_llama()\n",
    "# Single GPU, inference mode. Binding the result to `_` keeps the cell from\n",
    "# echoing the module repr.\n",
    "_ = llama.to(\"cuda\").eval()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99f9a7da",
   "metadata": {},
   "source": [
    "### Factual performance of instruct-tuned LLaMA-7B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7f5d9c7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Draw 5000 examples from factual_sampler; element 0 holds the input ids and\n",
    "# element 1 the labels used to sanity-check the un-intervened model.\n",
    "raw_prealign = factual_sampler(tokenizer, 5000, game=\"pricing_tag\")\n",
    "prealign_columns = {\"input_ids\": raw_prealign[0], \"labels\": raw_prealign[1]}\n",
    "prealign_dataset = Dataset.from_dict(prealign_columns)\n",
    "prealign_dataset.set_format(type=\"torch\", columns=list(prealign_columns))\n",
    "prealign_dataloader = DataLoader(prealign_dataset, batch_size=8)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "e4cb38ae",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 625/625 [00:48<00:00, 12.82it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: 0.92\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "total_count, correct_count = 0, 0\n",
    "with torch.no_grad():\n",
    "    for batch in tqdm(prealign_dataloader):\n",
    "        # Put every tensor of the batch on the same device as the model.\n",
    "        batch = {\n",
    "            name: value.to(llama.device) if isinstance(value, torch.Tensor) else value\n",
    "            for name, value in batch.items()\n",
    "        }\n",
    "\n",
    "        # Plain forward pass of the base model (no intervention involved).\n",
    "        outputs = llama(input_ids=batch[\"input_ids\"], labels=batch[\"labels\"])\n",
    "\n",
    "        # Only the last position is scored: label vs. argmax of the logits.\n",
    "        gold_last = batch[\"labels\"][:, -1]\n",
    "        predicted_last = outputs.logits[:, -1].argmax(dim=-1)\n",
    "        hits = gold_last == predicted_last\n",
    "\n",
    "        total_count += len(hits)\n",
    "        correct_count += hits.sum().tolist()\n",
    "current_acc = round(correct_count / total_count, 2)\n",
    "print(f\"[WARNING: THIS NEEDS TO BE GOOD!] prealign task accuracy: {current_acc}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "59d30b97",
   "metadata": {},
   "source": [
    "### Create training dataset for our trainable intervention (Boundless DAS)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "79caade7",
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(42)\n",
    "\n",
    "###################\n",
    "# data loaders\n",
    "###################\n",
    "raw_data = bound_alignment_sampler(\n",
    "    tokenizer, 10000, [lower_bound_alignment_example_sampler]\n",
    ")\n",
    "\n",
    "\n",
    "def _take_rows(fields, start, stop):\n",
    "    \"\"\"Slice rows [start:stop] out of each of the first four sampled fields.\"\"\"\n",
    "    return tuple(fields[i][start:stop] for i in range(4))\n",
    "\n",
    "\n",
    "def _build_loader(split):\n",
    "    \"\"\"Wrap one raw split as a torch-formatted Dataset plus a batch-16 DataLoader.\"\"\"\n",
    "    dataset = Dataset.from_dict(\n",
    "        dict(\n",
    "            input_ids=split[0],\n",
    "            source_input_ids=split[1],\n",
    "            labels=split[2],\n",
    "            intervention_ids=split[3],  # we will not use this field\n",
    "        )\n",
    "    ).with_format(\"torch\")\n",
    "    return dataset, DataLoader(dataset, batch_size=16)\n",
    "\n",
    "\n",
    "# First 8000 rows train, next 1000 eval, the remainder test (sampled order).\n",
    "raw_train = _take_rows(raw_data, None, 8000)\n",
    "raw_eval = _take_rows(raw_data, 8000, 9000)\n",
    "raw_test = _take_rows(raw_data, 9000, None)\n",
    "\n",
    "train_dataset, train_dataloader = _build_loader(raw_train)\n",
    "eval_dataset, eval_dataloader = _build_loader(raw_eval)\n",
    "test_dataset, test_dataloader = _build_loader(raw_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "47e83259",
   "metadata": {},
   "source": [
    "### Boundless DAS on Position-aligned Tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "296d0a11",
   "metadata": {},
   "outputs": [],
   "source": [
    "def simple_boundless_das_position_config(model_type, intervention_type, layer):\n",
    "    \"\"\"Build an IntervenableConfig with a single Boundless DAS site.\n",
    "\n",
    "    Args:\n",
    "        model_type: class of the model to wrap (the caller passes type(llama)).\n",
    "        intervention_type: forwarded as the second positional argument of\n",
    "            RepresentationConfig. Despite its name, the caller in this notebook\n",
    "            passes a component name (\"block_output\"); the intervention class\n",
    "            itself is fixed to BoundlessRotatedSpaceIntervention below.\n",
    "        layer: forwarded as the first positional argument of\n",
    "            RepresentationConfig.\n",
    "\n",
    "    Returns:\n",
    "        The IntervenableConfig describing that one representation.\n",
    "    \"\"\"\n",
    "    representation = RepresentationConfig(layer, intervention_type)\n",
    "    return IntervenableConfig(\n",
    "        model_type=model_type,\n",
    "        representations=[representation],\n",
    "        intervention_types=BoundlessRotatedSpaceIntervention,\n",
    "    )\n",
    "\n",
    "\n",
    "# One site: \"block_output\" at layer 15 of the loaded LLaMA.\n",
    "config = simple_boundless_das_position_config(type(llama), \"block_output\", 15)\n",
    "intervenable = IntervenableModel(config, llama)\n",
    "intervenable.set_device(\"cuda\")\n",
    "# Only the intervention parameters are meant to be trained.\n",
    "intervenable.disable_model_gradients()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "740e3724",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Schedule length in batches: 3 passes over the training loader (the same\n",
    "# value as `epochs` defined further down in this cell).\n",
    "# NOTE(review): the training loop steps the scheduler only on optimizer\n",
    "# updates (every `gradient_accumulation_steps` batches), so it advances fewer\n",
    "# than t_total times -- confirm this is intended.\n",
    "t_total = int(3 * len(train_dataloader))\n",
    "warm_up_steps = 0.1 * t_total\n",
    "\n",
    "# Two parameter groups per intervention: the rotation uses the optimizer's\n",
    "# default learning rate, the boundaries get their own (1e-2).\n",
    "optimizer_params = []\n",
    "for intervention in intervenable.interventions.values():\n",
    "    optimizer_params.append({\"params\": intervention.rotate_layer.parameters()})\n",
    "    optimizer_params.append(\n",
    "        {\"params\": intervention.intervention_boundaries, \"lr\": 1e-2}\n",
    "    )\n",
    "optimizer = torch.optim.Adam(optimizer_params, lr=1e-3)\n",
    "scheduler = get_linear_schedule_with_warmup(\n",
    "    optimizer,\n",
    "    num_warmup_steps=warm_up_steps,\n",
    "    num_training_steps=t_total,\n",
    ")\n",
    "\n",
    "\n",
    "# You can define your custom compute_metrics function.\n",
    "def compute_metrics(eval_preds, eval_labels):\n",
    "    \"\"\"Accuracy of the last-position prediction over a list of batches.\n",
    "\n",
    "    Args:\n",
    "        eval_preds: iterable of logit tensors, one per batch; only `[:, -1]`\n",
    "            is read and arg-maxed over its last dimension.\n",
    "        eval_labels: iterable of label tensors aligned with `eval_preds`; only\n",
    "            `[:, -1]` is read.\n",
    "\n",
    "    Returns:\n",
    "        {\"accuracy\": fraction of matching examples, rounded to 2 decimals}.\n",
    "    \"\"\"\n",
    "    n_seen = 0\n",
    "    n_hit = 0\n",
    "    for batch_logits, batch_labels in zip(eval_preds, eval_labels):\n",
    "        gold = batch_labels[:, -1]\n",
    "        guess = batch_logits[:, -1].argmax(dim=-1)\n",
    "        matches = gold == guess\n",
    "        n_seen += len(matches)\n",
    "        n_hit += matches.sum().tolist()\n",
    "    return {\"accuracy\": round(n_hit / n_seen, 2)}\n",
    "\n",
    "\n",
    "# Training-loop settings.\n",
    "epochs = 3\n",
    "gradient_accumulation_steps = 4\n",
    "total_step = 0  # global batch counter, advanced by the training loop\n",
    "target_total_step = len(train_dataloader) * epochs\n",
    "\n",
    "# Temperature goes linearly from start to end, one entry per training batch.\n",
    "temperature_start = 50.0\n",
    "temperature_end = 0.1\n",
    "linear_ramp = torch.linspace(\n",
    "    start=temperature_start, end=temperature_end, steps=target_total_step\n",
    ")\n",
    "temperature_schedule = linear_ramp.to(torch.bfloat16).to(\"cuda\")\n",
    "intervenable.set_temperature(temperature_schedule[total_step])\n",
    "\n",
    "\n",
    "def calculate_loss(logits, labels):\n",
    "    \"\"\"Cross-entropy over all positions plus a penalty on boundary size.\n",
    "\n",
    "    No shifting is applied: logits and labels are flattened and compared\n",
    "    position for position.\n",
    "\n",
    "    Args:\n",
    "        logits: tensor whose last dimension is the vocabulary\n",
    "            (`intervenable.model_config.vocab_size`).\n",
    "        labels: tensor of target ids with one entry per logit row.\n",
    "\n",
    "    Returns:\n",
    "        Scalar loss tensor: cross-entropy plus the sum of\n",
    "        `intervention_boundaries` over every intervention.\n",
    "    \"\"\"\n",
    "    vocab_size = intervenable.model_config.vocab_size\n",
    "    # Flatten the tokens\n",
    "    flat_logits = logits.contiguous().view(-1, vocab_size)\n",
    "    # Enable model parallelism: labels must live on the logits' device.\n",
    "    flat_labels = labels.contiguous().view(-1).to(flat_logits.device)\n",
    "    loss = CrossEntropyLoss()(flat_logits, flat_labels)\n",
    "\n",
    "    # Accumulate the penalty over *all* interventions. Previously the variable\n",
    "    # was overwritten on each pass, so only the last intervention counted and\n",
    "    # an empty set of interventions raised a NameError.\n",
    "    boundary_loss = sum(\n",
    "        1.0 * intervention.intervention_boundaries.sum()\n",
    "        for intervention in intervenable.interventions.values()\n",
    "    )\n",
    "    loss += boundary_loss\n",
    "\n",
    "    return loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "5cc2a247",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "llama trainable parameters:  0\n",
      "intervention trainable parameters:  16777218\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Epoch: 0: 100%|██████████████████████████████████████████████████████████████████████████████████████| 500/500 [07:05<00:00,  1.18it/s, loss=0.5, acc=0.88]\n",
      "Epoch: 1: 100%|█████████████████████████████████████████████████████████████████████████████████████| 500/500 [07:58<00:00,  1.04it/s, loss=0.39, acc=0.94]\n",
      "Epoch: 2: 100%|█████████████████████████████████████████████████████████████████████████████████████| 500/500 [08:19<00:00,  1.00it/s, loss=0.35, acc=0.94]\n",
      "Epoch: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [23:23<00:00, 467.83s/it]\n"
     ]
    }
   ],
   "source": [
    "# train() enables dropout; the model's own gradients stay disabled, so only\n",
    "# the intervention parameters are updated.\n",
    "intervenable.model.train()\n",
    "print(\"llama trainable parameters: \", count_parameters(intervenable.model))\n",
    "print(\"intervention trainable parameters: \", intervenable.count_parameters())\n",
    "train_iterator = trange(0, int(epochs), desc=\"Epoch\")\n",
    "for epoch in train_iterator:\n",
    "    epoch_iterator = tqdm(\n",
    "        train_dataloader, desc=f\"Epoch: {epoch}\", position=0, leave=True\n",
    "    )\n",
    "    for step, inputs in enumerate(epoch_iterator):\n",
    "        for k, v in inputs.items():\n",
    "            if v is not None and isinstance(v, torch.Tensor):\n",
    "                inputs[k] = v.to(\"cuda\")\n",
    "        # Run the base input while swapping in the source representation.\n",
    "        _, counterfactual_outputs = intervenable(\n",
    "            {\"input_ids\": inputs[\"input_ids\"]},\n",
    "            [{\"input_ids\": inputs[\"source_input_ids\"]}],\n",
    "            {\"sources->base\": 80},  # swap 80th token\n",
    "        )\n",
    "        eval_metrics = compute_metrics(\n",
    "            [counterfactual_outputs.logits], [inputs[\"labels\"]]\n",
    "        )\n",
    "\n",
    "        # loss and backprop\n",
    "        loss = calculate_loss(counterfactual_outputs.logits, inputs[\"labels\"])\n",
    "        loss_str = round(loss.item(), 2)\n",
    "        epoch_iterator.set_postfix({\"loss\": loss_str, \"acc\": eval_metrics[\"accuracy\"]})\n",
    "\n",
    "        if gradient_accumulation_steps > 1:\n",
    "            loss = loss / gradient_accumulation_steps\n",
    "        loss.backward()\n",
    "        # Update after every `gradient_accumulation_steps` batches. Testing\n",
    "        # total_step + 1 gives every update the same number of accumulated\n",
    "        # batches; testing total_step made the first update span one extra\n",
    "        # batch and left the gradients of the final batches unapplied.\n",
    "        if (total_step + 1) % gradient_accumulation_steps == 0:\n",
    "            optimizer.step()\n",
    "            scheduler.step()\n",
    "            intervenable.set_zero_grad()\n",
    "            intervenable.set_temperature(temperature_schedule[total_step])\n",
    "        total_step += 1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3323b113",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Test: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 63/63 [00:45<00:00,  1.38it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'accuracy': 0.96}\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# evaluation on the test set\n",
    "# Score in eval mode: the training cell leaves the wrapped model in train mode.\n",
    "intervenable.model.eval()\n",
    "eval_labels = []\n",
    "eval_preds = []\n",
    "with torch.no_grad():\n",
    "    epoch_iterator = tqdm(test_dataloader, desc=\"Test\")\n",
    "    for step, inputs in enumerate(epoch_iterator):\n",
    "        for k, v in inputs.items():\n",
    "            if v is not None and isinstance(v, torch.Tensor):\n",
    "                inputs[k] = v.to(\"cuda\")\n",
    "        _, counterfactual_outputs = intervenable(\n",
    "            {\"input_ids\": inputs[\"input_ids\"]},\n",
    "            [{\"input_ids\": inputs[\"source_input_ids\"]}],\n",
    "            {\"sources->base\": 80},  # swap 80th token\n",
    "        )\n",
    "        eval_labels += [inputs[\"labels\"]]\n",
    "        # compute_metrics only reads the last position. Keep a copy of just that\n",
    "        # slice (clone() so the full logits tensor can be freed) instead of\n",
    "        # holding every batch's complete logits on the GPU.\n",
    "        eval_preds += [counterfactual_outputs.logits[:, -1:].clone()]\n",
    "eval_metrics = compute_metrics(eval_preds, eval_labels)\n",
    "print(eval_metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9fbd6296",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
